#!/usr/bin/env python3

#####################################################################
#
# pg-drop is a tool to show/clear ingress PG (priority group) dropped packet stats.
#
#####################################################################
import _pickle as pickle
import argparse
import os
import sys
from collections import OrderedDict

from natsort import natsorted
from tabulate import tabulate

# Unit-test hook: when the UTILITIES_UNIT_TESTING environment variable is "2",
# put the parent directory and its tests/ directory at the front of sys.path
# and import mock_tables.dbconnector so the tool runs against mocked redis
# tables (presumably the import itself installs the mocks -- confirm in
# tests/mock_tables/dbconnector.py).
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector

except KeyError:
    # UTILITIES_UNIT_TESTING is not set: regular (non-test) run, nothing to do.
    pass

from swsscommon.swsscommon import ConfigDBConnector, SonicV2Connector

# Placeholder printed in a table cell when a counter cannot be read from the DB.
STATUS_NA = 'N/A'

# Prefix prepended to an object ID to form the key of its counters table.
COUNTER_TABLE_PREFIX = "COUNTERS:"

# Names of the COUNTERS_DB hashes used by this tool:
#   COUNTERS_PORT_NAME_MAP - port name -> port ID
#   COUNTERS_PG_NAME_MAP   - PG name   -> PG object ID
#   COUNTERS_PG_PORT_MAP   - PG object ID -> ID of the port the PG belongs to
#   COUNTERS_PG_INDEX_MAP  - PG object ID -> PG index
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
COUNTERS_PG_NAME_MAP = "COUNTERS_PG_NAME_MAP"
COUNTERS_PG_PORT_MAP = "COUNTERS_PG_PORT_MAP"
COUNTERS_PG_INDEX_MAP = "COUNTERS_PG_INDEX_MAP"

def get_dropstat_dir():
    """
        Return the per-user directory path (with a trailing slash) used to
        store drop-counter checkpoints, built from the caller's numeric uid.
    """
    base_path = '/tmp/dropstat'
    current_uid = os.getuid()
    return base_path + "-" + str(current_uid) + "/"

class PgDropStat(object):
    """
        Show and clear ingress priority-group (PG) dropped packet counters.

        Counter values are read from COUNTERS_DB. "Clearing" does not touch
        the database: it pickles the current values to a per-user checkpoint
        file, and later "show" calls print the difference against it.
    """

    def __init__(self):
        """
            Connect to COUNTERS_DB and CONFIG_DB and build the port/PG maps.

            Sets:
                counter_port_name_map - port name -> port ID (from the DB)
                port_name_map         - port ID -> port name
                port_pg_map           - port name -> {PG name: PG object ID}
                pg_drop_types         - description of each printable table

            Exits with status 1 if one of the name maps is missing or a PG
            has no entry in COUNTERS_PG_PORT_MAP.
        """
        self.counters_db = SonicV2Connector(host='127.0.0.1')
        self.counters_db.connect(self.counters_db.COUNTERS_DB)

        self.configdb = ConfigDBConnector()
        self.configdb.connect()

        dropstat_dir = get_dropstat_dir()
        self.port_drop_stats_file = os.path.join(dropstat_dir, 'pg_drop_stats')

        def get_port_id(oid):
            """
                Get port ID using object ID
            """
            port_id = self.counters_db.get(self.counters_db.COUNTERS_DB, COUNTERS_PG_PORT_MAP, oid)
            if port_id is None:
                print("Port is not available for oid '{}'".format(oid), file=sys.stderr)
                sys.exit(1)
            return port_id

        # Get all ports
        self.counter_port_name_map = self.counters_db.get_all(self.counters_db.COUNTERS_DB, COUNTERS_PORT_NAME_MAP)
        if self.counter_port_name_map is None:
            print("COUNTERS_PORT_NAME_MAP is empty!", file=sys.stderr)
            sys.exit(1)

        self.port_pg_map = {}
        self.port_name_map = {}

        for port in self.counter_port_name_map:
            self.port_pg_map[port] = {}
            self.port_name_map[self.counter_port_name_map[port]] = port

        # Get PGs for each port
        counter_pg_name_map = self.counters_db.get_all(self.counters_db.COUNTERS_DB, COUNTERS_PG_NAME_MAP)
        if counter_pg_name_map is None:
            print("COUNTERS_PG_NAME_MAP is empty!", file=sys.stderr)
            sys.exit(1)

        for pg in counter_pg_name_map:
            port = self.port_name_map[get_port_id(counter_pg_name_map[pg])]
            self.port_pg_map[port][pg] = counter_pg_name_map[pg]

        self.pg_drop_types = {
            "pg_drop"       : {"message" : "Ingress PG dropped packets:",
                               "obj_map" : self.port_pg_map,
                               "idx_func": self.get_pg_index,
                               "counter_name" : "SAI_INGRESS_PRIORITY_GROUP_STAT_DROPPED_PACKETS",
                               "header_prefix": "PG"},
        }

    def get_pg_index(self, oid):
        """
            return PG index (0-7)

            oid - object ID for entry in redis

            Exits with status 1 if the index is not present in
            COUNTERS_PG_INDEX_MAP.
        """
        pg_index = self.counters_db.get(self.counters_db.COUNTERS_DB, COUNTERS_PG_INDEX_MAP, oid)
        if pg_index is None:
            # Report the oid that was looked up (the original referenced an
            # undefined name here and died with NameError instead).
            print("Priority group index is not available for oid '{}'".format(oid), file=sys.stderr)
            sys.exit(1)
        return pg_index

    def build_header(self, pg_drop_type):
        """
            Construct header for table with PG counters

            pg_drop_type - one entry of self.pg_drop_types; its "obj_map",
                           "idx_func" and "header_prefix" keys are used.

            Sets self.header_list to ['Port', '<prefix><min>', ..., '<prefix><max>']
            and self.min_idx to the smallest index found. The range covers the
            indices of every port, so any counter handled later by
            get_counters() maps to a valid column. With no counters at all the
            header is just ['Port'] and min_idx is 0.

            Exits with status 1 if pg_drop_type is None.
        """
        if pg_drop_type is None:
            print("Header info is not available!", file=sys.stderr)
            sys.exit(1)

        idx_func = pg_drop_type["idx_func"]
        min_idx = sys.maxsize
        max_idx = -1

        # Scan all ports, not only the first one: a port whose indices fall
        # outside the first port's range would otherwise index out of bounds
        # (or wrap around to a wrong column) in get_counters().
        for port_objs in pg_drop_type["obj_map"].values():
            for counter_oid in port_objs.values():
                curr_idx = int(idx_func(counter_oid))
                min_idx = min(min_idx, curr_idx)
                max_idx = max(max_idx, curr_idx)

        if max_idx < 0:
            # No counters found on any port.
            min_idx = 0

        self.min_idx = min_idx
        self.header_list = ['Port']
        self.header_list += ["{}{}".format(pg_drop_type["header_prefix"], idx) for idx in range(min_idx, max_idx + 1)]

    def get_counters(self, table_prefix, port_obj, idx_func, counter_name):
        """
            Get the counters of a specific table.

            table_prefix - prefix prepended to each object ID to form the table key
            port_obj     - {PG name: PG object ID} for one port
            idx_func     - callable mapping an object ID to its column index
            counter_name - field to read from each counters table

            Returns a list of strings, one per header column after 'Port':
            the counter value minus the checkpointed value, "0" for columns
            without an object, or STATUS_NA when the counter is unavailable.

            build_header() must have been called first (uses self.header_list
            and self.min_idx).
        """
        port_drop_ckpt = {}
        # Grab the latest clear checkpoint, if it exists
        if os.path.isfile(self.port_drop_stats_file):
            # NOTE(security): pickle executes arbitrary code on load and this
            # file lives under /tmp; consider a safer format (e.g. JSON).
            with open(self.port_drop_stats_file, 'rb') as ckpt_file:
                port_drop_ckpt = pickle.load(ckpt_file)

        # Header list contains the port name followed by the PGs. Fields is used to populate the pg values
        fields = ["0"] * (len(self.header_list) - 1)

        for name, obj_id in port_obj.items():
            full_table_id = table_prefix + obj_id
            # An object that did not exist when the checkpoint was taken has
            # no baseline: treat it as 0 instead of raising KeyError.
            old_collected_data = port_drop_ckpt.get(name, {}).get(full_table_id, 0)
            idx = int(idx_func(obj_id))
            pos = idx - self.min_idx
            counter_data = self.counters_db.get(self.counters_db.COUNTERS_DB, full_table_id, counter_name)
            if counter_data is None:
                fields[pos] = STATUS_NA
            elif fields[pos] != STATUS_NA:
                fields[pos] = str(int(counter_data) - old_collected_data)
        return fields

    def print_all_stat(self, table_prefix, key):
        """
            Print table that show stats per PG

            table_prefix - prefix of the counters tables
            key          - key into self.pg_drop_types selecting the table
        """
        table = []
        drop_type = self.pg_drop_types[key]
        self.build_header(drop_type)
        # Get stat for each port
        for port in natsorted(self.counter_port_name_map):
            row_data = list()
            data = self.get_counters(table_prefix, drop_type["obj_map"][port], drop_type["idx_func"], drop_type["counter_name"])
            row_data.append(port)
            row_data.extend(data)
            table.append(tuple(row_data))

        print(drop_type["message"])
        print(tabulate(table, self.header_list, tablefmt='simple', stralign='right'))

    def get_counts(self, counters, oid):
        """
            Get the PG drop counts for an individual counter.

            counters - iterable of counter field names to read
            oid      - object ID whose counters table is read

            Returns {table id: count}; a missing counter is recorded as 0.
            The result is keyed by table id only, so if several counter
            names are given just the last one is kept.
        """
        counts = {}
        table_id = COUNTER_TABLE_PREFIX + oid
        for counter in counters:
            counter_data = self.counters_db.get(self.counters_db.COUNTERS_DB, table_id, counter)
            if counter_data is None:
                counts[table_id] = 0
            else:
                counts[table_id] = int(counter_data)
        return counts

    def get_counts_table(self, counters, object_table):
        """
            Returns a dictionary containing a mapping from an object (like a port)
            to its PG drop counts. Counts are contained in a dictionary that maps
            counter oid to its counts.

            An empty OrderedDict is returned when object_table is not in the DB.
        """
        counter_object_name_map = self.counters_db.get_all(self.counters_db.COUNTERS_DB, object_table)
        current_stat_dict = OrderedDict()

        if counter_object_name_map is None:
            return current_stat_dict

        for obj in natsorted(counter_object_name_map):
            current_stat_dict[obj] = self.get_counts(counters, counter_object_name_map[obj])
        return current_stat_dict

    def clear_drop_counts(self):
        """
            Clears the current PG drop counter.

            Saves the current counter values to the checkpoint file; on an
            I/O error prints it and exits with the error's errno.
        """

        counter_pg_drop_array = ["SAI_INGRESS_PRIORITY_GROUP_STAT_DROPPED_PACKETS"]
        try:
            # Collect the counters before opening (and truncating) the file,
            # then let the context manager close the handle.
            current_counts = self.get_counts_table(
                counter_pg_drop_array,
                COUNTERS_PG_NAME_MAP)
            with open(self.port_drop_stats_file, 'wb+') as ckpt_file:
                pickle.dump(current_counts, ckpt_file)
        except IOError as e:
            print(e)
            sys.exit(e.errno)
        print("Cleared PG drop counter")

    def check_if_stats_enabled(self):
        """
            Print a warning and exit with status 0 when the PG_DROP entry of
            FLEX_COUNTER_TABLE exists and its FLEX_COUNTER_STATUS is
            "disable" (or absent). Does nothing if the entry is missing.
        """
        pg_drop_info = self.configdb.get_entry('FLEX_COUNTER_TABLE', 'PG_DROP')
        if pg_drop_info:
            status = pg_drop_info.get("FLEX_COUNTER_STATUS", 'disable')
            if status == "disable":
                print("Warning: PG counters are disabled. Use 'counterpoll pg-drop enable' to enable polling")
                sys.exit(0)

def main():
    """
        Command-line entry point.

        Parses the -c/--command option, makes sure the per-user checkpoint
        directory exists, then runs 'clear' or 'show'. Any other command
        prints "Command not recognized". Exits with status 0 unless the
        directory cannot be created (then exits with that error's errno).
    """
    parser = argparse.ArgumentParser(description='Display PG drop counter',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
Examples:
pg-drop -c show
pg-drop -c clear
""")

    parser.add_argument('-c', '--command', type=str, help='Desired action to perform')

    args = parser.parse_args()
    command = args.command

    dropstat_dir = get_dropstat_dir()
    # Create the directory to hold clear results. exist_ok=True avoids the
    # race between checking for the directory and creating it when two
    # invocations start at the same time.
    try:
        os.makedirs(dropstat_dir, exist_ok=True)
    except OSError as e:
        print(e)
        sys.exit(e.errno)

    pgdropstat = PgDropStat()

    if command == 'clear':
        pgdropstat.clear_drop_counts()
    elif command == 'show':
        pgdropstat.check_if_stats_enabled()
        pgdropstat.print_all_stat(COUNTER_TABLE_PREFIX, "pg_drop")
    else:
        print("Command not recognized")
    sys.exit(0)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
